"""
Common Imports and Visualization Functions for SWOT Internal Tide Analysis

This module provides all necessary imports and common functions for analyzing
SWOT satellite data and internal tide variance reduction using HYCOM and HRET models.
It includes visualization utilities, data processing tools, and configuration settings.

Author: B. Yadidya
Created: 2025
Purpose: Support SWOT 21-day analysis for internal tide variance reduction studies

Dependencies:
- NumPy, xarray, pandas for data manipulation
- Matplotlib, cartopy, holoviews for visualization  
- SciPy for spatial processing and signal analysis
- Dask for parallel computing
- Custom utilities from /home/yadidya/21day_swot/
"""

# Core data manipulation and analysis libraries
import numpy as np                           # Numerical computations
import xarray as xr                         # Multi-dimensional labeled arrays
import pandas as pd                         # Data analysis and manipulation
import matplotlib.pyplot as plt            # Plotting and visualization
import cmaps as cmp                        # Custom colormaps
import os                                  # Operating system interface
import datetime                           # Date and time handling
import scipy.fftpack                      # Fast Fourier Transform tools
import time                              # Time-related functions
import sys                               # System-specific parameters and functions
import glob                              # Unix-style pathname pattern expansion
import warnings                          # Warning control
import colormaps 

# Geospatial and mapping libraries
import cartopy.crs as ccrs                          # Coordinate reference systems
import cartopy.feature as cfeature                 # Geographic features
from cartopy.mpl.geoaxes import GeoAxes
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter  # Coordinate formatters
from mpl_toolkits.basemap import Basemap          # Alternative mapping toolkit
from mpl_toolkits.axes_grid1.inset_locator import inset_axes  # Subplot positioning
from matplotlib.gridspec import GridSpec

# Color and styling libraries
import matplotlib.colors as mcolors                # Color utilities
import seaborn as sns                             # Statistical data visualization
from matplotlib.colors import ListedColormap, LinearSegmentedColormap, BoundaryNorm
from matplotlib.patches import Rectangle          # Geometric shapes for plots

# Parallel computing and distributed processing
from dask_jobqueue import SLURMCluster           # SLURM cluster interface
from dask.distributed import Client             # Distributed computing client
import dask.array as da                         # Large array processing

# Scientific analysis libraries
from astropy.timeseries import LombScargle      # Astronomical time series analysis
from scipy.signal import lombscargle           # Signal processing
from scipy.spatial import cKDTree              # Spatial data structures
import scipy.ndimage as ndimage               # N-dimensional image processing

# Interactive visualization and web-based plotting
import holoviews as hv                         # Declarative data visualization
from bokeh.models import FixedTicker          # Interactive plot widgets
import hvplot.xarray                          # xarray plotting interface
import xhistogram.xarray as xh               # Histogram computations

# Advanced data analysis
import xrft                                   # xarray FFT tools
from IPython.display import display, HTML   # Jupyter notebook display utilities

# Make project-specific helpers importable.
# NOTE(review): this appends the filesystem root '/', while the module
# docstring says the utilities live in /home/yadidya/21day_swot/ -- confirm
# the intended directory.
sys.path.append('/')
# Wildcard import: every public name in `utils` lands in this namespace and,
# because it runs after the imports above, can silently rebind any of them
# if `utils` happens to define the same name.
from utils import *  # Import custom utility functions for SWOT analysis

# Silence RuntimeWarning and FutureWarning. These filters are process-global:
# they affect all code running in the interpreter, not just this module.
warnings.simplefilter("ignore", category=RuntimeWarning)
warnings.simplefilter("ignore", category=FutureWarning)

# Load the Bokeh backend so holoviews/hvplot objects render interactively.
hv.extension('bokeh')

# Fonts: register Myriad Pro files with matplotlib so the 'Myriad Pro' entry
# in the rcParams below can resolve.
import matplotlib.font_manager as fm
# UPDATE THIS PATH to the actual folder where you put the .OTF files.
# A relative path is resolved against the current working directory at
# import time, not against this file's location.
font_folder = 'fonts'

# Paths to the specific font files.
# Note: Linux filenames are case-sensitive. Ensure these match your files exactly.
regular_font_path = os.path.join(font_folder, 'MYRIADPRO-REGULAR.OTF')
bold_font_path    = os.path.join(font_folder, 'MYRIADPRO-BOLD.OTF')
# NOTE(review): despite the variable name, 'CONDIT' looks like a condensed
# italic face -- confirm this is the intended italic file.
italic_font_path  = os.path.join(font_folder, 'MYRIADPRO-CONDIT.OTF')

# Register each file with matplotlib's global font manager. This runs at
# import time; presumably a missing file raises here and aborts the import
# of this module -- verify.
fm.fontManager.addfont(regular_font_path)
fm.fontManager.addfont(bold_font_path)
fm.fontManager.addfont(italic_font_path)

# Ready-made 9 pt FontProperties bound directly to the bold font file, for
# callers that want the bold face explicitly.
bold_font_props = fm.FontProperties(fname=bold_font_path, size=9)

# Global matplotlib defaults for compact, consistent figure appearance.
# plt.rcParams['figure.figsize'] = (10, 6)    # Default figure size in inches

plt.rcParams.update({
    # Prefer Myriad Pro (registered above); fall back to DejaVu Sans, then Arial.
    'font.family': 'sans-serif',
    'font.sans-serif': ['Myriad Pro', 'DejaVu Sans', 'Arial'],
    'font.size': 10,
    'axes.labelsize': 9,
    'axes.titlesize': 10,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 10,
    'figure.titlesize': 10,
    # Thin axes, ticks, grid and lines for small publication-size panels.
    'axes.linewidth': 0.5,
    'xtick.major.width': 0.5,
    'ytick.major.width': 0.5,
    'grid.linewidth': 0.5,
    'lines.linewidth': 1.0,
    # Very high resolution: every new figure, including inline notebook
    # output, is created at 500 dpi, which can make rendering slow and
    # memory-heavy.
    'figure.dpi': 500
})

def plot_map_with_bbox(lon_min, lon_max, lat_min, lat_max, edgecolor='red', linewidth=1.5):
    """
    Create a global map with a highlighted bounding-box region.

    Useful for visualizing study areas, SWOT pass coverage regions, or any
    rectangular geographic area of interest in a global context.

    Parameters
    ----------
    lon_min : float
        Western edge of the bounding box (degrees East). Any longitude
        convention is accepted (e.g. -160.5 or 199.5); the value is wrapped
        onto the base map's 20-380 degree longitude window.
    lon_max : float
        Eastern edge of the bounding box (degrees East). The box extends
        *eastward* from ``lon_min`` to ``lon_max``, so a box crossing the
        0/360 meridian can be given as e.g. ``lon_min=350, lon_max=10``.
    lat_min : float
        Minimum latitude of bounding box (degrees North, -90 to 90).
    lat_max : float
        Maximum latitude of bounding box (degrees North, -90 to 90).
    edgecolor : str, optional
        Color of the bounding box border (default: 'red').
    linewidth : float, optional
        Width of the bounding box border in points (default: 1.5).

    Returns
    -------
    None
        Displays the plot directly.

    Notes
    -----
    - The base map is drawn by plot_global_map_on_axes_basemap, whose
      longitude axis runs from 20 to 380 degrees East. Box longitudes are
      wrapped into that window. If a box runs past the right edge, it is
      drawn a second time shifted 360 degrees west, so the overflow appears
      on the left edge.
    - A box whose longitude span is a non-zero multiple of 360 degrees
      (e.g. 0 to 360) covers the whole longitude range rather than
      collapsing to zero width.
    - Rectangle coordinates are in data space, which for the cylindrical
      base map is degrees of longitude/latitude.
    - Figure size is (10, 5).

    Examples
    --------
    >>> # Highlight the Hawaiian region
    >>> plot_map_with_bbox(199.5, 205.5, 18, 23, edgecolor='blue')

    >>> # Show a Pacific study area
    >>> plot_map_with_bbox(180, 220, -10, 10)

    >>> # A box crossing the Greenwich meridian
    >>> plot_map_with_bbox(350, 10, 30, 45)
    """
    _, ax = plt.subplots(figsize=(10, 5))

    # Add base map with coastlines, land fill and graticule.
    plot_global_map_on_axes_basemap(ax)

    # Show slightly more latitude than the base map's own +/-63 degree extent.
    ax.set_ylim(-70, 70)

    # Western edge of the base map's longitude window (llcrnrlon=20, urcrnrlon=380).
    window_west = 20.0

    # Eastward longitude span of the box. Working modulo 360 handles boxes
    # that cross the 0/360 meridian (lon_max < lon_min) and mixed
    # conventions (e.g. -154.5 vs 199.5). A non-zero multiple of 360 means
    # the box wraps the whole globe rather than having zero width.
    span = (lon_max - lon_min) % 360
    if span == 0 and lon_max != lon_min:
        span = 360.0

    # Move the western edge into [window_west, window_west + 360) so it is on-map.
    west = lon_min % 360
    if west < window_west:
        west += 360

    height = lat_max - lat_min  # Latitude span in degrees

    # If the box runs past the right edge of the window, draw a copy shifted
    # one full turn west so the wrapped part shows up at the left edge.
    starts = [west]
    if west + span > window_west + 360:
        starts.append(west - 360)

    for start in starts:
        ax.add_patch(Rectangle(
            xy=(start, lat_min),    # Bottom-left corner in degrees
            width=span,             # Width in degrees longitude (always >= 0)
            height=height,          # Height in degrees latitude
            edgecolor=edgecolor,
            linewidth=linewidth,
            facecolor='none',       # Outline only
            transform=ax.transData,
        ))

    # Display the plot immediately
    plt.show()


def get_swot_passes_in_bbox(lats, lons, passes, min_lat, max_lat, min_lon, max_lon):
    """
    Extract SWOT satellite pass numbers that intersect a geographic bounding box.

    Identifies which SWOT orbital passes contain at least one pixel inside a
    study region, enabling focused regional analysis.

    Parameters
    ----------
    lats : np.ndarray or xarray.DataArray
        3D array of latitudes with shape (pass_num, num_lines, num_pixels).
    lons : np.ndarray or xarray.DataArray
        3D array of longitudes with shape (pass_num, num_lines, num_pixels),
        in [0, 360] degrees East.
    passes : np.ndarray or xarray.DataArray
        1D array of pass numbers with shape (pass_num,).
    min_lat, max_lat : float
        Latitude bounds of the box (degrees North); bounds are inclusive.
    min_lon, max_lon : float
        Longitude bounds of the box (degrees East, 0 to 360); bounds are inclusive.

    Returns
    -------
    np.ndarray
        Sorted array of unique integer pass numbers whose swath has at least
        one pixel inside the box. It is an empty *integer* array when no pass
        matches.

    Raises
    ------
    ValueError
        If the input arrays have incompatible shapes or dimensions, if
        ``min_lat >= max_lat`` or ``min_lon >= max_lon``, or if ``lons``
        contains values outside [0, 360].

    Warns
    -----
    UserWarning
        If the box longitudes themselves lie outside [0, 360].

    Notes
    -----
    - Longitude coordinates must be in [0, 360] degrees East. In that
      convention the 180-degree dateline is continuous (e.g. Hawaii is
      ~199.5 to 205.5). The discontinuity is at the 0/360 meridian, and a
      box crossing it cannot be expressed because ``min_lon < max_lon`` is
      required; query the two halves separately and combine the results.
    - NaN coordinates never count as inside the box.
    - Passes are evaluated one at a time, so only one pass's mask is held
      in memory.

    Examples
    --------
    >>> # Find passes over the Hawaiian Islands
    >>> hawaii_passes = get_swot_passes_in_bbox(
    ...     lats, lons, passes,
    ...     min_lat=18, max_lat=23,
    ...     min_lon=199.5, max_lon=205.5
    ... )

    >>> # Identify passes over the Gulf Stream region
    >>> gulf_passes = get_swot_passes_in_bbox(
    ...     lats, lons, passes,
    ...     min_lat=35, max_lat=45,
    ...     min_lon=285, max_lon=305  # 75W to 55W converted to degrees East
    ... )
    """
    # --- Input validation ------------------------------------------------
    if lats.shape != lons.shape:
        raise ValueError(f"Latitude shape {lats.shape} does not match longitude shape {lons.shape}.")

    if lats.ndim != 3 or lons.ndim != 3:
        raise ValueError(f"Expected 3D latitude/longitude arrays, got {lats.ndim}D and {lons.ndim}D.")

    if passes.ndim != 1:
        raise ValueError(f"Expected 1D passes array, got {passes.ndim}D.")

    if lats.shape[0] != passes.shape[0]:
        raise ValueError(f"Pass array size {passes.shape[0]} does not match first dimension of lat/lon {lats.shape[0]}.")

    if min_lat >= max_lat:
        raise ValueError("min_lat must be less than max_lat.")
    if min_lon >= max_lon:
        raise ValueError("min_lon must be less than max_lon.")

    # Out-of-range box longitudes are suspicious but not fatal. Emit a real
    # warning, so callers can filter or escalate it, instead of printing.
    if min_lon < 0 or max_lon < 0 or min_lon > 360 or max_lon > 360:
        warnings.warn(
            "Longitudes outside [0, 360] range. For Hawaii, use ~199.5 to 205.5 "
            "(e.g., 360 - 160.5 to 360 - 154.5).",
            stacklevel=2,
        )

    if (lons < 0).any() or (lons > 360).any():
        raise ValueError("Longitude array contains values outside [0, 360].")

    # --- Per-pass intersection test --------------------------------------
    valid_passes = []
    for i, pass_num in enumerate(passes):
        # np.asarray accepts both NumPy arrays and xarray DataArrays (it
        # yields the underlying values). Only one pass is materialized at a time.
        pass_lats = np.asarray(lats[i])  # Shape: (num_lines, num_pixels)
        pass_lons = np.asarray(lons[i])  # Shape: (num_lines, num_pixels)

        inside = ((pass_lats >= min_lat) & (pass_lats <= max_lat)
                  & (pass_lons >= min_lon) & (pass_lons <= max_lon))

        if inside.any():
            valid_passes.append(int(pass_num))

    # dtype=int keeps the result integer even when nothing matched;
    # np.unique([]) would otherwise return an empty float64 array.
    return np.unique(np.asarray(valid_passes, dtype=int))


def plot_global_map_on_axes_basemap(ax, label_style=1, *, draw_labels=False):
    """
    Draw a global cylindrical Basemap (coastlines, land fill, graticule) on ``ax``.

    The map is centred on 180 degrees East, spans longitudes 20-380 degrees
    East and latitudes -63 to 63 degrees North. Parallels are drawn every
    30 degrees (-60..60) and meridians every 60 degrees (0..360), as thin
    dashed light-gray lines.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The Axes object to draw on.
    label_style : int, optional
        Which graticule edges get labels when ``draw_labels`` is True:
        1: Parallels left, Meridians bottom (default)
        2: Parallels left, Meridians none
        3: Parallels none, Meridians bottom
        4: No labels
    draw_labels : bool, optional (keyword-only)
        If True, label parallels/meridians according to ``label_style``.
        Default False: graticule lines are drawn without labels, which is
        how this function has always rendered.

    Raises
    ------
    ValueError
        If ``label_style`` is not 1, 2, 3 or 4. This is checked before
        anything is drawn, so ``ax`` is left untouched.
    """
    # Basemap label flags are [left, right, top, bottom] for
    # (parallels, meridians).
    label_flags = {
        1: ([1, 0, 0, 0], [0, 0, 0, 1]),  # Left & Bottom
        2: ([1, 0, 0, 0], [0, 0, 0, 0]),  # Left only
        3: ([0, 0, 0, 0], [0, 0, 0, 1]),  # Bottom only
        4: ([0, 0, 0, 0], [0, 0, 0, 0]),  # None
    }
    # Validate before touching the axes. The tuple membership test uses ==,
    # matching the original comparison chain (no hashing required).
    if label_style not in (1, 2, 3, 4):
        raise ValueError("label_style must be 1, 2, 3, or 4.")
    parallel_labels, meridian_labels = label_flags[label_style]

    m = Basemap(
        projection='cyl', lon_0=180,
        llcrnrlat=-63, urcrnrlat=63, llcrnrlon=20, urcrnrlon=380, ax=ax)

    m.drawcoastlines(color='grey')
    m.fillcontinents(color='grey')

    grid_style = dict(linewidth=0.5, color='lightgray', dashes=[2, 2], fontsize=6)
    parallel_kw = dict(grid_style)
    meridian_kw = dict(grid_style)
    # Only pass `labels` when requested, so the default call is identical to
    # the unlabeled graticule this function has always drawn.
    if draw_labels:
        parallel_kw['labels'] = parallel_labels
        meridian_kw['labels'] = meridian_labels

    m.drawparallels(range(-60, 61, 30), **parallel_kw)
    m.drawmeridians(range(0, 361, 60), **meridian_kw)
    
def shift_and_sort_longitude(data_array, lon_min=20):
    """
    Wrap the 'longitude' coordinate into [lon_min, lon_min + 360) and sort by it.

    With the default ``lon_min=20``, longitudes end up in [20, 380): values
    that fall in [0, 20) after reduction modulo 360 are moved up by 360
    degrees.

    Parameters
    ----------
    data_array : xarray.DataArray or xarray.Dataset
        Object with a 1-D dimension coordinate named 'longitude'.
    lon_min : float, optional
        Western edge of the output longitude window, in degrees. Must lie in
        [-360, 360). Default 20.

    Returns
    -------
    Same type as ``data_array``
        A new object whose 'longitude' values are wrapped into the window
        and sorted in ascending order. The original coordinate attributes
        (e.g. units, long_name) are preserved.
    """
    lon = data_array["longitude"]

    # Reduce to [0, 360) first so the result does not depend on the input
    # convention (-180..180, 0..360, ...).
    wrapped = np.asarray(lon) % 360
    # Shift into [lon_min, lon_min + 360). The first step handles
    # lon_min >= 0 and the second handles negative lon_min; for the default
    # window only the first step ever changes anything.
    wrapped = np.where(wrapped < lon_min, wrapped + 360, wrapped)
    wrapped = np.where(wrapped >= lon_min + 360, wrapped - 360, wrapped)

    # Re-attach the original coordinate attributes, which a bare
    # (dim, values) tuple would otherwise drop.
    adjusted = data_array.assign_coords(longitude=("longitude", wrapped, lon.attrs))

    # Sort by longitude so the wrapped segment sits in monotonic order.
    return adjusted.sortby("longitude")